{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Online CNMF-E\n",
    "\n",
    "This demo shows an example of doing online analysis on one-photon data. We compare offline and online approaches. The dataset used is courtesy of the Miniscope project."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import bokeh.plotting as bpl\n",
    "import holoviews as hv\n",
    "from IPython import get_ipython\n",
    "import logging\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from time import time\n",
    "\n",
    "import caiman as cm\n",
    "from caiman.source_extraction import cnmf as cnmf\n",
    "from caiman.motion_correction import MotionCorrect\n",
    "from caiman.utils.utils import download_demo\n",
    "from caiman.utils.visualization import nb_inspect_correlation_pnr\n",
    "\n",
    "try:\n",
    "    if __IPYTHON__:\n",
    "        get_ipython().run_line_magic('load_ext', 'autoreload')\n",
    "        get_ipython().run_line_magic('autoreload', '2')\n",
    "except NameError:\n",
    "    pass\n",
    "\n",
    "logfile = None # Replace with a path if you want to log to a file\n",
    "logger = logging.getLogger('caiman')\n",
    "# Set to logging.INFO if you want much output, potentially much more output\n",
    "logger.setLevel(logging.WARNING)\n",
    "logfmt = logging.Formatter('%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s] [%(process)d] %(message)s')\n",
    "if logfile is not None:\n",
    "    handler = logging.FileHandler(logfile)\n",
    "else:\n",
    "    handler = logging.StreamHandler()\n",
    "handler.setFormatter(logfmt)\n",
    "logger.addHandler(handler)\n",
    "\n",
    "bpl.output_notebook()\n",
    "hv.notebook_extension('bokeh')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Select file(s) to be processed\n",
    "The `download_demo` function will download the specific file for you and return the complete path to the file which will be stored in your `caiman_data` directory. If you adapt this demo for your data make sure to pass the complete path to your file(s). Remember to pass the `fnames` variable as a list. Note that the memory requirement of the offline CNMF-E algorithm are much higher compared to the standard CNMF algorithm. One of the benefits of the online approach is the reduced memory requirements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "movie_ind = 0   # 0 for avi, 1 for the tif (in case avi loading troubles)\n",
    "fnames = ['msCam13.avi', 'data_endoscope.tif']  # filename to be processed\n",
    "fnames = [download_demo(fnames[movie_ind])] "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (optional) View movie\n",
    "If you do not wish to view (e.g., your movie is too large or you are in certain Jupyter environments), set `play_movies` to `False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Press `q` to close the movie\n",
    "play_movies = True\n",
    "temporal_downsampling = 0.2\n",
    "if play_movies:\n",
    "    movie_orig = cm.load(fnames[0])\n",
    "    movie_orig.resize(1, 1, temporal_downsampling).play(gain = 0.6, \n",
    "                                                        q_max=99.9, \n",
    "                                                        q_min=1,  \n",
    "                                                        fr=20, \n",
    "                                                        magnification=1, \n",
    "                                                        plot_text=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch (offline) approach\n",
    "\n",
    "We start with motion correction and then proceed with the source extraction using the CNMF-E algorithm. For a detailed 1p demo check `demo_pipeline_cnmfE.ipynb`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# motion correction parameters\n",
    "motion_correct = True            # flag for performing motion correction\n",
    "pw_rigid = False                 # flag for performing piecewise-rigid motion correction (otherwise just rigid)\n",
    "gSig_filt = (7, 7)               # size of high pass spatial filtering, used in 1p data\n",
    "max_shifts = (20, 20)            # maximum allowed rigid shift\n",
    "border_nan = 'copy'              # replicate values along the boundaries\n",
    "\n",
    "mc_dict = {\n",
    "    'pw_rigid': pw_rigid,\n",
    "    'max_shifts': max_shifts,\n",
    "    'gSig_filt': gSig_filt,\n",
    "    'border_nan': border_nan\n",
    "}\n",
    "\n",
    "opts = cnmf.params.CNMFParams(params_dict=mc_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%% start a cluster for parallel processing (if a cluster already exists it will be closed and a new session will be opened)\n",
    "if 'dview' in locals():\n",
    "    cm.stop_server(dview=dview)\n",
    "c, dview, n_processes = cm.cluster.setup_cluster(backend='multiprocessing', \n",
    "                                                 n_processes=None,  # None\n",
    "                                                 single_thread=False)  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize motion correction estimator with the parameters, and run to fit data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "mc = MotionCorrect(fnames, dview=dview, **opts.get_group('motion')) \n",
    "mc.motion_correct(save_movie=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### inspect motion correction results\n",
    "This will load into memory so can take a while."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inspect_results = True\n",
    "if inspect_results and play_movies:\n",
    "    temporal_downsampling = 0.5\n",
    "    cm.concatenate((cm.load(fnames).resize(1, 1, temporal_downsampling), \n",
    "                    cm.load(mc.mmap_file).resize(1, 1, temporal_downsampling)), axis=2).play(fr=30,\n",
    "                                                         magnification=1,\n",
    "                                                         gain=0.6,\n",
    "                                                         plot_text=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot shifts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(); \n",
    "plt.plot(mc.shifts_rig); \n",
    "plt.legend(['x-shifts', 'y-shifts']); "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The motion correction results look good. We then proceed with memory mapping and checking the correlation/pnr images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fname_new = cm.save_memmap(mc.mmap_file, base_name='memmap_', order='C',\n",
    "                           border_to_0=0, dview=dview)\n",
    "Yr, dims, T = cm.load_memmap(fname_new)\n",
    "images = Yr.T.reshape((T,) + dims, order='F')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inspect correlation and PNR images to set relevant thresholds\n",
    "First, extract correlation and pnr images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gSig = (6, 6)\n",
    "cn_filter, pnr = cm.summary_images.correlation_pnr(images[::max(T//1000, 1)], \n",
    "                                                   gSig=gSig[0], \n",
    "                                                   swap_dim=False) # change swap dim if output looks weird, it is a problem with tiffile"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inspect the summary images so you can set parameters in the following section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_inspect_correlation_pnr(cn_filter, pnr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set parameters for source extraction\n",
    "From the images above we select `min_pnr = 10` and `min_corr = 0.8`. We pass these alongside the other parameters needed for offline 1p processing."
   ]
  },
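  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of how the two thresholds act together, the sketch below marks a pixel as a candidate seed only when both its correlation and its PNR clear the chosen values. The small arrays are invented for illustration, and this is a simplified picture under that assumption, not CaImAn's actual initialization code.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical 2 x 2 summary images (values invented for illustration)\n",
    "corr = np.array([[0.90, 0.50],\n",
    "                 [0.85, 0.95]])\n",
    "pnr = np.array([[12.0, 20.0],\n",
    "                [8.0, 15.0]])\n",
    "\n",
    "min_corr = 0.8\n",
    "min_pnr = 10\n",
    "\n",
    "# A pixel qualifies only if it passes both thresholds\n",
    "candidates = (corr >= min_corr) & (pnr >= min_pnr)\n",
    "print(candidates)\n",
    "```\n",
    "\n",
    "Raising either threshold can only shrink the set of candidate pixels, so stricter values trade missed neurons for fewer false positives."
   ]
  },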
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "min_pnr = 10  # estimate from corr_pnr above\n",
    "min_corr = 0.8 # estimate from corr_pnr above\n",
    "rf = 48                                        # half size of each patch\n",
    "stride = 8                                     # amount of overlap between patches     \n",
    "ssub = 1                                       # spatial downsampling factor   \n",
    "decay_time = 0.4                               # length of typical transient (in seconds)\n",
    "fr = 10                                        # imaging rate (Hz) \n",
    "gSig = (6, 6)                                  # expected half size of neurons\n",
    "gSiz = (15, 15)                                # half size for neuron bounding box   \n",
    "p = 0                                          # order of AR indicator dynamics\n",
    "min_SNR = 1.5                                  # minimum SNR for accepting new components\n",
    "rval_thr = 0.85                                # correlation threshold for new component inclusion\n",
    "merge_thr = 0.65                               # merging threshold\n",
    "K = None                                       # initial number of components\n",
    "\n",
    "cnmfe_dict = {'fnames': fnames,\n",
    "              'fr': fr,\n",
    "              'decay_time': decay_time,\n",
    "              'method_init': 'corr_pnr',\n",
    "              'gSig': gSig,\n",
    "              'gSiz': gSiz,\n",
    "              'rf': rf,\n",
    "              'stride': stride,\n",
    "              'p': p,\n",
    "              'nb': 0,\n",
    "              'ssub': ssub,\n",
    "              'min_SNR': min_SNR,\n",
    "              'min_pnr': min_pnr,\n",
    "              'min_corr': min_corr,\n",
    "              'bas_nonneg': False,\n",
    "              'center_psf': True,\n",
    "              'rval_thr': rval_thr,\n",
    "              'only_init': True,\n",
    "              'merge_thr': merge_thr,\n",
    "              'K': K}\n",
    "opts.change_params(cnmfe_dict);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t1 = -time()\n",
    "cnm = cnmf.CNMF(n_processes=n_processes, dview=dview, params=opts)\n",
    "cnm.fit(images)\n",
    "t1 += time()\n",
    "print(f\"Elapsed time: {t1: 0.2f} seconds\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## View the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm.estimates.plot_contours_nb(img=pnr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm.estimates.hv_view_components(img=cn_filter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Show a movie with the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if play_movies:\n",
    "    cnm.estimates.play_movie(images, magnification=0.75, include_bck=False);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Online Processing\n",
    "Now try the online approach. The idea behind the online algorithm is simple:\n",
    "- First initialize the estimates by running the batch (offline) algorithm in small subset.\n",
    "- Then process each frame as it arrives. The processing consists of:\n",
    "    * Motion correct the new frame\n",
    "    * Extract the activity of existing neurons at this frame, and neuropil\n",
    "    * Search for new neurons that appear in this frame and have not been detected earlier.\n",
    "- Periodically update shapes of existing neurons and background model.\n",
    "\n",
    "## Setup additional parameters for online processing"
   ]
  },
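  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The control flow described above (initialize on a batch, then stream frames, refreshing the model periodically) can be sketched with a toy example. Everything here is a stand-in: `toy_online_pipeline` is a hypothetical helper, the background model is just a mean image, and motion correction and the search for new neurons are omitted. It is not how OnACID is implemented.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def toy_online_pipeline(frames, init_batch=5, update_freq=3):\n",
    "    # 1) Initialization: fit a (toy) background model on the first init_batch frames\n",
    "    background = frames[:init_batch].mean(axis=0)\n",
    "    traces = []\n",
    "    n_updates = 0\n",
    "    for t in range(init_batch, len(frames)):\n",
    "        # 2) Per-frame step: summarize what the current model does not explain\n",
    "        traces.append(float((frames[t] - background).sum()))\n",
    "        # 3) Periodic step: refresh the model using all frames seen so far\n",
    "        if (t - init_batch + 1) % update_freq == 0:\n",
    "            background = frames[:t + 1].mean(axis=0)\n",
    "            n_updates += 1\n",
    "    return np.array(traces), n_updates\n",
    "\n",
    "frames = np.ones((11, 2, 2))  # a constant toy movie with 11 frames\n",
    "traces, n_updates = toy_online_pipeline(frames, init_batch=5, update_freq=3)\n",
    "print(traces.shape, n_updates)\n",
    "```\n",
    "\n",
    "The arguments `init_batch` and `update_freq` play the same roles as the parameters of the same name set below: how many frames seed the model, and how often the model is refreshed."
   ]
  },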
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "online_opts = deepcopy(cnm.params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf = 48                                        # half size of patch (used only during initialization)\n",
    "stride = 8                                     # overlap between patches (used only during initialization) \n",
    "ssub = 1                                       # spatial downsampling factor (during initialization)\n",
    "ds_factor = 2*ssub                             # spatial downsampling factor (during online processing)         \n",
    "ssub_B = 4                                     # background downsampling factor (use that for faster processing)\n",
    "gSig = (10//ds_factor, 10//ds_factor)          # expected half size of neurons\n",
    "gSiz = (22//ds_factor, 22//ds_factor)\n",
    "sniper_mode = False                            # flag using a CNN to detect new neurons (o/w space correlation is used)\n",
    "init_batch = 300                               # number of frames for initialization (presumably from the first file)\n",
    "expected_comps = 500                           # maximum number of expected components used for memory pre-allocation (exaggerate here)\n",
    "dist_shape_update = False                      # flag for updating shapes in a distributed way\n",
    "min_num_trial = 5                              # number of candidate components per frame     \n",
    "K = None                                       # initial number of components\n",
    "epochs = 2                                     # number of passes over the data\n",
    "show_movie = False                             # show the movie with the results as the data gets processed\n",
    "use_corr_img = True                            # flag for using the corr*pnr image when searching for new neurons (otherwise residual)\n",
    "\n",
    "online_dict = {'epochs': epochs,\n",
    "               'nb': 0,\n",
    "               'ssub': ssub,\n",
    "               'ssub_B': ssub_B,\n",
    "               'ds_factor': ds_factor,                                   # ds_factor >= ssub should hold\n",
    "               'gSig': gSig,\n",
    "               'gSiz': gSiz,\n",
    "               'gSig_filt': (3, 3),\n",
    "               'min_corr': min_corr,\n",
    "               'bas_nonneg': False,\n",
    "               'center_psf': True,\n",
    "               'max_shifts_online': 20,\n",
    "               'rval_thr': rval_thr,\n",
    "               'motion_correct': True,\n",
    "               'init_batch': init_batch,\n",
    "               'only_init': True,\n",
    "               'init_method': 'cnmf',\n",
    "               'normalize_init': False,\n",
    "               'update_freq': 200,\n",
    "               'expected_comps': expected_comps,\n",
    "               'sniper_mode': sniper_mode,                               # set to False for 1p data       \n",
    "               'dist_shape_update' : dist_shape_update,\n",
    "               'min_num_trial': min_num_trial,\n",
    "               'epochs': epochs,\n",
    "               'use_corr_img': use_corr_img,\n",
    "               'show_movie': show_movie}\n",
    "online_opts.change_params(online_dict);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize estimator\n",
    "cnm_online = cnmf.online_cnmf.OnACID(params=online_opts, dview=dview)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run online estimator\n",
    "cnm_online.fit_online();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#images = cm.load(fnames[0], subindices=slice(0,1000))\n",
    "#Cn, pnr = cm.summary_images.correlation_pnr(images[::1], gSig=gSig[0], swap_dim=False) # change swap dim if output looks weird, it is a problem with tiffile\n",
    "cnm_online.estimates.nb_view_components(img=pnr, denoised_color='red');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm_online.estimates.plot_contours_nb(img=pnr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot timing\n",
    "The plot below shows the time spent on each part of the algorithm (motion correction, tracking of current components, detect new components, update shapes) for each frame. Note that if you displayed a movie while processing the data (`show_movie=True`) the time required to generate this movie will be included here."
   ]
  },
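  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the bookkeeping behind such a plot, the example below uses invented timings in place of the values recorded by the online estimator. It pads the frames used for initialization with zeros, attributes whatever is not motion correction to the remaining steps, and accumulates each category over frames.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "init_batch = 3                  # hypothetical number of initialization frames\n",
    "t_motion = [0.01, 0.02, 0.01]   # hypothetical per-frame motion correction times (s)\n",
    "t_online = [0.05, 0.06, 0.04]   # hypothetical total per-frame processing times (s)\n",
    "\n",
    "# Pad with zeros so that index i corresponds to frame i of the movie\n",
    "T_motion = np.array([0] * init_batch + t_motion)\n",
    "T_total = np.array([0] * init_batch + t_online)\n",
    "\n",
    "# Whatever is not motion correction is attributed to the remaining steps\n",
    "T_rest = T_total - T_motion\n",
    "\n",
    "# Cumulative time per category, which is what a stacked plot displays\n",
    "cum_motion = np.cumsum(T_motion)\n",
    "cum_rest = np.cumsum(T_rest)\n",
    "print(cum_motion[-1] + cum_rest[-1])\n",
    "```\n",
    "\n",
    "The two cumulative curves add up to the total online processing time, so the top edge of the stack tracks the overall cost."
   ]
  },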
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_cumulative = True\n",
    "#if show_cumulative:\n",
    "T_init = np.array([cnm_online.t_init] + [0]*(epochs*T-1))\n",
    "T_motion = 1e3*np.array([0]*init_batch + cnm_online.t_motion)/1e3\n",
    "T_detect = 1e3*np.array([0]*init_batch + cnm_online.t_detect)/1e3\n",
    "T_shapes = 1e3*np.array([0]*init_batch + cnm_online.t_shapes)/1e3\n",
    "T_online = 1e3*np.array([0]*init_batch + cnm_online.t_online)/1e3 - T_motion - T_detect - T_shapes\n",
    "plt.figure()\n",
    "plt.stackplot(np.arange(len(T_motion)), np.cumsum(T_init), np.cumsum(T_motion), np.cumsum(T_online), np.cumsum(T_detect), np.cumsum(T_shapes))\n",
    "plt.legend(labels=['init', 'motion', 'process', 'detect', 'shapes'], loc=2)\n",
    "for i in range(epochs - 1):\n",
    "    plt.plot([(i+1)*T, (i+1)*T], [0, np.array(cnm_online.t_online).sum()+cnm_online.t_init], '--k')\n",
    "plt.title('Processing time allocation')\n",
    "plt.xlabel('Frame #')\n",
    "plt.ylabel('Processing time [ms]')\n",
    "#plt.ylim([0, 1.2e3*np.percentile(np.array(cnm_online.t_online), 90)]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if play_movies:\n",
    "    cnm_online.estimates.play_movie(imgs=images, magnification=0.75, include_bck=False);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean up and compare two approaches\n",
    "\n",
    "Even though the online algorithm screens any new components, we can still perform the quality tests to filter out any false positive components. To do that, we first need to apply the inferred shifts to the original data in order to have the whole registered dataset in memory mapped form."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if online_opts.online['motion_correct']:\n",
    "    shifts = cnm_online.estimates.shifts[-cnm_online.estimates.C.shape[-1]:]\n",
    "    if not opts.motion['pw_rigid']:\n",
    "        memmap_file = cm.motion_correction.apply_shift_online(images, shifts,\n",
    "                                                              save_base_name='MC')\n",
    "    else:\n",
    "        mc = MotionCorrect(fnames, dview=dview, **online_opts.get_group('motion'))\n",
    "\n",
    "        mc.y_shifts_els = [[sx[0] for sx in sh] for sh in shifts]\n",
    "        mc.x_shifts_els = [[sx[1] for sx in sh] for sh in shifts]\n",
    "        memmap_file = mc.apply_shifts_movie(fnames, rigid_shifts=False,\n",
    "                                            save_memmap=True,\n",
    "                                            save_base_name='MC')\n",
    "else:  # To do: apply non-rigid shifts on the fly\n",
    "    memmap_file = images.save(fnames[0][:-4] + 'mmap')\n",
    "cnm_online.mmap_file = memmap_file\n",
    "Yr_online, dims, T = cm.load_memmap(memmap_file)\n",
    "\n",
    "#cnm_online.estimates.dview=dview\n",
    "#cnm_online.estimates.compute_residuals(Yr=Yr_online)\n",
    "images_online = np.reshape(Yr_online.T, [T] + list(dims), order='F')\n",
    "min_SNR = 2  # peak SNR for accepted components (if above this, accept)\n",
    "rval_thr = 0.85  # space correlation threshold (if above this, accept)\n",
    "use_cnn = False # use the CNN classifier\n",
    "cnm_online.params.change_params({'min_SNR': min_SNR,\n",
    "                                'rval_thr': rval_thr,\n",
    "                                'use_cnn': use_cnn})\n",
    "\n",
    "cnm_online.estimates.evaluate_components(images_online, cnm_online.params, dview=dview)\n",
    "cnm_online.estimates.Cn = pnr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm_online.estimates.plot_contours_nb(img=pnr, idx=cnm_online.estimates.idx_components)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm_online.estimates.hv_view_components(img=pnr, idx=cnm_online.estimates.idx_components,\n",
    "                                        denoised_color='red')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm_online.estimates.hv_view_components(img=pnr, \n",
    "                                        idx=cnm_online.estimates.idx_components_bad,\n",
    "                                        denoised_color='red')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Difference in inferred shifts\n",
    "\n",
    "Accurate motion correction is important for the online algorithm. Below we plot the difference in the estimated shifts between the two approaches. Note that the online shifts have been rescaled by a factor of `ds_factor`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(np.array(mc.shifts_rig) - ds_factor*np.array(cnm_online.estimates.shifts[:1000]));\n",
    "plt.legend(['x-shifts', 'y-shifts']);\n",
    "plt.title('Difference between offline and online shifts')\n",
    "plt.xlabel('Frame #')\n",
    "plt.ylabel('pixels')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Constant shifts in the FOV will not significantly affect the results. What is most important is deviatons. "
   ]
  },
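  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why a constant offset is harmless for this comparison, consider the toy example below. The shift values are invented, and the online shifts are assumed to live on a grid downsampled by `ds_factor`, so they are multiplied by `ds_factor` before being compared with the offline ones.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "ds_factor = 2\n",
    "\n",
    "# Hypothetical offline shifts in full-resolution pixels, one row per frame\n",
    "offline = np.array([[2.0, -4.0],\n",
    "                    [1.0, 0.0],\n",
    "                    [3.0, 2.0]])\n",
    "\n",
    "# Hypothetical online shifts: the same motion on the downsampled grid plus a constant offset\n",
    "online = offline / ds_factor + 0.25\n",
    "\n",
    "diff = offline - ds_factor * online  # the same constant in every entry\n",
    "spread = np.std(diff, axis=0)        # a constant difference has no spread\n",
    "print(diff, spread)\n",
    "```\n",
    "\n",
    "The difference is nonzero but identical in every frame, so its standard deviation along the frame axis is zero. Only frame-to-frame disagreement between the two estimates contributes to that number."
   ]
  },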
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.std(np.array(mc.shifts_rig) - ds_factor*np.array(cnm_online.estimates.shifts[:1000]), axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The standard deviation is at a subpixel level (although it can still be significant). The high degree of similarity can also be seen from the correlation between the shifts of the two approaches."
   ]
  },
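  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When reading the output of `np.corrcoef` called with two arrays of two rows each, it helps to know that the rows are stacked into four variables, so the result is a 4 x 4 matrix. The sketch below uses invented shifts to show which entries compare the two approaches. It also shows that rescaling by a positive factor such as `ds_factor` does not change the correlation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical offline shifts, one row per shift coordinate and one column per frame\n",
    "offline = np.array([[1.0, 2.0, 4.0, 3.0],\n",
    "                    [0.0, -1.0, 1.0, 2.0]])\n",
    "\n",
    "# Hypothetical online shifts: a rescaled and offset copy of the offline ones\n",
    "online = offline / 2 + 0.25\n",
    "\n",
    "R = np.corrcoef(offline, online)  # rows 0-1 are offline, rows 2-3 are online\n",
    "print(R.shape, R[0, 2], R[1, 3])\n",
    "```\n",
    "\n",
    "Entries `R[0, 2]` and `R[1, 3]` compare matching coordinates across the two approaches. Values close to 1 there indicate that the two sets of shifts move together, regardless of scale or offset."
   ]
  },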
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.corrcoef(np.array(mc.shifts_rig).T, np.array(cnm_online.estimates.shifts[:1000]).T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "caiman_pytorch2",
   "language": "python",
   "name": "caiman_pytorch2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
